import asyncio
import json
import os
from typing import List, Dict, Any
from typing import (
    Annotated,
    Sequence,
    TypedDict,
)
from typing import TypedDict, Optional, Literal
from typing import Tuple, Union

import numpy as np

from utils import extract_json, encode_image, load_and_encode_image,cleanup_temp_images
from agent_tools_normalized5_v1 import *
from download_data import *
from prompt_lib4_v2 import *
from model_wrapper import Message, ChatModel, get_llm_with_tools

# Expose only GPU index 0 to CUDA-aware libraries used by this process.
# NOTE(review): this assignment runs after the project imports above; if any of
# those modules initializes CUDA at import time the setting may come too late —
# confirm, and move it above those imports if so.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Define State
class Plan(TypedDict, total=False):
    """Planner-produced switches that enable or disable pipeline stages.

    Every key is optional (``total=False``); the pipeline steps in this file
    read each flag with a default of ``True`` when it is absent.
    """

    # Read by distortion_step: run the distortion-detection LLM call.
    distortion_detection: bool
    # Read by distortion_step: run the distortion-analysis LLM call.
    distortion_analysis: bool
    # Read by tool_step: ask the LLM which IQA tools to use.
    tool_selection: bool
    # Read by tool_step: actually invoke the required/selected IQA tools.
    tool_execute: bool

class IQAState(TypedDict):
    """Mutable state dictionary threaded through every pipeline step.

    The annotations below describe the values the steps in this file actually
    store. Several fields start out as ``None`` (see ``run_iqa_analysis``) and
    are filled in later, hence the ``Optional`` wrappers.
    """

    # The user's question.
    query: str
    # Starts as None; summarize_quality_step compares it against "others".
    query_type: Optional[str]
    # Image paths; index 0 is the distorted image, index 1 (if any) the reference.
    images: List[str]
    # Mapping of region name -> list of detected distortions
    # (distortion_step falls back to {"global": []}).
    distortions: Optional[Dict[str, List[str]]]
    # Mapping of region name -> list of {"type", "severity", "explanation"} dicts.
    distortion_analysis: Optional[Dict[str, List[Dict[str, str]]]]
    # "Full-Reference" or "No-Reference".
    reference_type: str
    # Free-text rationale copied from the summarizer reply.
    reason: Optional[str]
    # Tool names that must be executed instead of LLM-selected ones.
    required_tool: Optional[List[str]]
    # Region/object names for object-level questions.
    object_names: Optional[List[str]]
    # Multiple-choice options (empty list for open-ended questions).
    choices_list: List[str]
    distortion_source: Optional[str]
    # Region -> {key: score}. Values are either a raw tool score or a
    # (tool_name, score) tuple, depending on which execution path produced them.
    quality_scores: Optional[Dict[str, Dict[str, Any]]]
    # Region -> {distortion: tool name}, written by tool_selection_step.
    selected_tools: Optional[Dict[str, Dict[str, str]]]
    # Numeric score for open-ended questions, otherwise the summarizer's
    # textual answer (or an "Unsupported: ..." notice).
    final_response: Optional[Union[float, str]]
    plan: Optional[Plan]
    summary: Optional[str]
    # Error payload; always a dict with at least a "message" key when set.
    error: Optional[Dict[str, Any]]
    messages: Optional[List[Message]]

# Registry of IQA tools: maps the tool name used in `required_tool` /
# `selected_tools` to the tool object that tool_execution_step invokes.
selected_tool = {
    "TopIQ_FR_tool": TopIQ_FR_tool,
    "AHIQ_tool": AHIQ_tool,
    "FSIM_tool": FSIM_tool,
    "LPIPS_tool": LPIPS_tool,
    "DISTS_tool": DISTS_tool,
    "WaDIQaM_FR_tool": WaDIQaM_FR_tool,
    "PieAPP_tool": PieAPP_tool,
    "MS_SSIM_tool": MS_SSIM_tool,
    "GMSD_tool": GMSD_tool,
    "SSIM_tool": SSIM_tool,
    "CKDN_tool": CKDN_tool,
    "VIF_tool": VIF_tool,
    "PSNR_tool": PSNR_tool,
    "VSI_tool": VSI_tool,
    "QAlign_tool": QAlign_tool,
    "CLIPIQA_tool": CLIPIQA_tool,
    "UNIQUE_tool": UNIQUE_tool,
    "HyperIQA_tool": HyperIQA_tool,
    "TReS_tool": TReS_tool,
    "WaDIQaM_NR_tool": WaDIQaM_NR_tool,
    "DBCNN_tool": DBCNN_tool,
    "ARNIQA_tool": ARNIQA_tool,
    "NIMA_tool": NIMA_tool,
    "BRISQUE_tool": BRISQUE_tool,
    "NIQE_tool": NIQE_tool,
    "MANIQA_tool": MANIQA_tool,
    "LIQE_mix_tool": LIQE_mix_tool,
}

# ======================== Planner ========================
async def planner_step(state: IQAState, model) -> IQAState:
    """Ask the planner model for a plan and merge its JSON reply into ``state``.

    The user message is the question followed by the answer choices (one per
    line) when ``choices_list`` is a list. Every key of the parsed reply is
    copied into ``state`` via ``dict.update``.

    Args:
        state: Pipeline state; mutated in place.
        model: Chat model exposing an awaitable ``ainvoke(messages)``.

    Returns:
        The same ``state`` object. On success any previous ``error`` entry is
        removed; on failure ``state["error"]`` holds a dict with ``message``
        and ``image_paths``.
    """
    images = state.get("images") or []
    try:
        question = state.get('query', "")
        choices_list = state.get('choices_list', "")

        if isinstance(choices_list, list):
            query = question + "\n" + "\n".join(choices_list)
        else:
            query = question

        system_prompt = build_planner_prompt()
        messages = [
            Message("system", system_prompt),
            Message("user", query)
        ]

        response = await model.ainvoke(messages)
        parsed_output = extract_json(response)
        print("[Planner Output]", parsed_output)

        if isinstance(parsed_output, dict) and parsed_output:
            state.update(parsed_output)
            # A successful parse must not inherit the error of an earlier,
            # failed attempt; otherwise the retry wrapper can never succeed.
            state.pop("error", None)
        else:
            state["error"] = {
                "message": "Failed to parse JSON response from Planner.",
                "image_paths": images
            }
        return state

    except Exception as e:
        print(f"[Planner Error] {e}")
        # Use the pre-fetched list: indexing state["images"] here could raise
        # again inside the handler.
        state["error"] = {
            "message": f"Planner exception: {e}",
            "image_paths": images
        }
        return state

async def planner_step_with_retry(state: IQAState, model, max_retries=2):
    """Run ``planner_step`` until it yields a plan or ``max_retries`` is exhausted.

    An attempt counts as successful when the resulting state has no ``error``
    and a truthy ``plan``.

    Args:
        state: Pipeline state; mutated in place.
        model: Chat model forwarded to ``planner_step``.
        max_retries: Maximum number of attempts.

    Returns:
        The state of the first successful attempt, or ``state`` with
        ``error`` set to a "Planner failed after retries." payload.
    """
    for attempt in range(max_retries):
        # Start each attempt clean so that a stale error from an earlier
        # attempt cannot mask a later success.
        state.pop("error", None)
        result = await planner_step(state, model)
        if (not result.get("error")) and result.get("plan"):
            return result
        print(f"[Planner Retry] Attempt {attempt + 1} failed.")
        # Back off only when another attempt will actually follow.
        if attempt + 1 < max_retries:
            await asyncio.sleep(1)

    state["error"] = {
        "message": "Planner failed after retries.",
        "image_paths": state.get("images", [])
    }
    return state

# ======================== Distortion Step ========================
async def distortion_step(state: IQAState, model):
    """Detect distortions in the distorted image and then analyze them.

    Both sub-steps are optional and controlled by ``state["plan"]``
    (``distortion_detection`` / ``distortion_analysis``, each defaulting to
    ``True``). Images are sent to the model as base64 data URLs; the reference
    image is included only when ``reference_type`` is ``"Full-Reference"``.

    Args:
        state: Pipeline state; mutated in place. Writes ``distortions`` and
            ``distortion_analysis``; writes ``error`` on failure.
        model: Chat model exposing an awaitable ``ainvoke(messages)``.

    Returns:
        The same ``state`` object.
    """
    # ``get(key, default)`` only applies the default when the key is missing;
    # the state is created with plan=None, so fall back explicitly.
    plan = state.get("plan") or {}
    detection_required = plan.get("distortion_detection", True)
    analysis_required = plan.get("distortion_analysis", True)
    choices_list = state.get("choices_list", [])
    choice_dict  = {chr(65+i): v for i,v in enumerate(choices_list)} if choices_list else None
    user_question=(
                        f"Question: {state['query']}\n\n"
                        "Answer choices:\n" + "\n".join(f"{k}. {v}" for k,v in (choice_dict or {}).items())
                    )
    try:
        dist_path = state["images"][0]
        ref_path = state["images"][1] if state.get("reference_type") == "Full-Reference" else None
        # Object-level mode applies when at least one real object name is given.
        is_object_level = state.get("object_names") not in (None, [None], [])

        image_contents = []
        for path in [dist_path, ref_path]:
            if path:
                image_contents.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encode_image(path)}"}
                })
        if ref_path:
            img_expl = "The first image is the distorted image. The second image is the reference image."
        else:
            img_expl = "This image is the distorted image."

        # ===== Step 1: Distortion Detection =====
        if detection_required:
            print("[Distortion Detection] Running...")
            prompt = build_distortion_detection_prompt(is_object_level,user_question)
            query = (
                f"{img_expl}\n\nPlease analyze the specific regions: {state.get('object_names')} of the distorted image."
                if is_object_level else
                f"{img_expl}\n\nPlease analyze the distorted image."
            )
            messages = [Message("system", prompt), Message("user", [{"type": "text", "text": query}] + image_contents)]
            response = await model.ainvoke(messages)
            parsed = extract_json(response)
            print("[Detection Result]", parsed)

            if parsed and "distortions" in parsed and isinstance(parsed["distortions"], dict):
                state["distortions"] = parsed["distortions"]
            else:
                print("[Distortion Detection] No valid distortion output, fallback to empty.")
                state["distortions"] = {"global": []}  # fallback empty
        else:
            print("[Distortion Detection] Skipped")

        # ===== Step 2: Distortion Analysis =====
        if analysis_required:
            print("[Distortion Analysis] Running...")
            # ``distortions`` is None when detection was skipped; treat that as
            # "nothing detected" instead of failing on ``None.values()``.
            distortion_dict = state.get("distortions") or {}

            all_empty = all(isinstance(v, list) and len(v) == 0 for v in distortion_dict.values())
            if all_empty:
                # No distortions anywhere: synthesize a "None" entry per region
                # and skip the LLM call.
                print("[Distortion Analysis] No distortion detected, skipping LLM analysis.")
                state["distortion_analysis"] = {
                    k: [{"type": "None", "severity": "None", "explanation": "No visible distortions detected."}]
                    for k in distortion_dict.keys()
                }
                return state

            prompt = build_distortion_analysis_prompt_multi_object(distortion_dict=distortion_dict, has_reference=(ref_path is not None),user_question=user_question)
            query = f"{img_expl}\n\nPlease analyze the following regions and their corresponding distortions."
            messages = [Message("system", prompt), Message("user", [{"type": "text", "text": query}] + image_contents)]
            response = await model.ainvoke(messages)
            parsed = extract_json(response)
            print("[Analysis Result]", parsed)

            if parsed and "distortion_analysis" in parsed:
                state["distortion_analysis"] = parsed["distortion_analysis"]
            else:
                state["error"] = {"message": "Failed to parse analysis result"}
                return state
        else:
            print("[Distortion Analysis] Skipped")

    except Exception as e:
        state["error"] = {"message": f"Distortion step failed: {e}"}
    return state

# ======================== Tool Selection + Execution ========================
def build_tool_call_args(tool_instance, state: IQAState) -> Dict[str, str]:
    """Build the keyword arguments for invoking an IQA tool.

    The positional parameter names of ``tool_instance.func`` decide the shape:
    tools taking ``reference_image`` and ``distorted_image`` receive both paths
    (the reference is ``None`` when only one image is present); tools taking
    ``image_url`` receive the distorted image path only.

    Raises:
        ValueError: If the tool has no usable ``func`` attribute, or its
            parameters match neither supported shape.
    """
    image_paths = state["images"]
    distorted = image_paths[0]
    reference = image_paths[1] if len(image_paths) > 1 else None

    wrapped = getattr(tool_instance, "func", None)
    if not wrapped:
        raise ValueError("Invalid tool instance (no 'func' attribute)")

    # Positional parameters come first in co_varnames.
    code = wrapped.__code__
    arg_names = code.co_varnames[:code.co_argcount]

    takes_image_pair = "reference_image" in arg_names and "distorted_image" in arg_names
    if takes_image_pair:
        return {"reference_image": reference, "distorted_image": distorted}
    if "image_url" in arg_names:
        return {"image_url": distorted}
    raise ValueError(f"Unsupported tool signature: {arg_names}")

async def tool_selection_step(state: IQAState, model) -> IQAState:
    """Ask the model which IQA tool to use for the detected distortions.

    Sends the tool-selection system prompt plus the ``distortions`` mapping and
    stores the reply's ``selected_tools`` field in ``state``.

    Returns:
        The same ``state`` object. On success ``error`` is removed; otherwise
        ``error`` carries a "Tool selection failed: ..." message.
    """
    try:
        detected = state.get("distortions", {})
        system_prompt = build_tool_prompt()
        user_text = f"Distortions: {detected}"
        conversation = [Message("system", system_prompt), Message("user", user_text)]
        reply = await model.ainvoke(conversation)
        parsed = extract_json(reply)
        print("[Tool Selection]", parsed)

        if parsed is None:
            failure = "Tool selection failed: unable to parse response"
        elif "selected_tools" in parsed:
            state["selected_tools"] = parsed["selected_tools"]
            state.pop("error", None)
            return state
        else:
            # Reply parsed, but the expected field is absent.
            failure = "Tool selection failed: missing selected_tools field"

        state["error"] = {"message": failure}
        return state

    except Exception as e:
        state["error"] = {"message": f"Tool selection failed: {e}"}
        return state

async def tool_execution_step(state: IQAState) -> IQAState:
    """Invoke IQA tools and store their scores in ``state["quality_scores"]``.

    Two paths exist:

    * ``required_tool`` is non-empty: every listed tool is run and its raw
      score is stored under ``quality_scores["global"][tool_name]``.
    * otherwise: ``selected_tools`` (region -> distortion -> tool name) is
      walked and each score is stored as ``(tool_name, score)`` under
      ``quality_scores[region][distortion]``.

    The first unknown tool or failing invocation aborts the step with
    ``state["error"]`` set; on success ``error`` is removed.
    """

    def fail(message: str) -> IQAState:
        # Report the problem on stdout and record it in the state.
        print(message)
        state["error"] = {"message": message}
        return state

    scores = {}
    required = state.get("required_tool")

    # ===== Required Tool Path =====
    if required:
        global_scores = scores["global"] = {}
        for name in required:
            tool = selected_tool.get(name)
            if not tool:
                return fail(f"[Tool Error] Required tool '{name}' not found.")
            try:
                call_args = build_tool_call_args(tool, state)
                global_scores[name] = tool.invoke(call_args)
            except Exception as e:
                return fail(f"[Execution Error] {name}: {e}")

        state["quality_scores"] = scores
        state.pop("error", None)
        return state

    # ===== Selected Tool Path =====
    chosen = state.get("selected_tools", {})
    try:
        for obj_name, dist_map in chosen.items():
            for dist_type, raw_name in dist_map.items():
                # Drop the "functions." prefix some replies put on tool names.
                name = raw_name.replace("functions.", "")
                tool = selected_tool.get(name)
                if not tool:
                    return fail(f"[Tool Error] Tool '{name}' not found.")
                try:
                    call_args = build_tool_call_args(tool, state)
                    call_args["distortion"] = dist_type
                    score = tool.invoke(call_args)
                    scores.setdefault(obj_name, {})[dist_type] = (name, score)
                except Exception as e:
                    return fail(f"[Execution Error] {name} for {obj_name}/{dist_type}: {e}")

        state["quality_scores"] = scores
        state.pop("error", None)
        return state
    except Exception as e:
        state["error"] = {"message": f"Tool execution failed: {e}"}
        return state

async def tool_step(state: IQAState):
    """Select and execute IQA tools according to the plan.

    Skips everything and assigns a default score of 5.0 when detection found
    no distortions (``distortions == {"global": []}``). Otherwise runs tool
    selection and tool execution, each gated by the plan flags
    ``tool_selection`` / ``tool_execute`` (default ``True``).

    Returns:
        The same ``state`` object; ``error`` is set if a sub-step failed and
        removed otherwise.
    """
    # ``get(key, default)`` only applies the default when the key is missing;
    # the state is created with plan=None, so fall back explicitly.
    plan = state.get("plan") or {}
    do_selection = plan.get("tool_selection", True)
    do_execute = plan.get("tool_execute", True)

    if state.get("distortions") == {"global": []}:
        print("[Tool Step] Skipped due to no visible distortions.")
        state["selected_tools"] = None
        state["quality_scores"] = {"global": {"default": 5.0}}
        state.pop("error", None)
        return state

    ref_path = state["images"][1] if len(state["images"]) > 1 else None
    model = get_llm_with_tools(ref_path)

    # Tool Selection
    if do_selection:
        state = await tool_selection_step(state, model)
        if state.get("error"):
            return state
    else:
        print("[Tool Selection] Skipped")

    # Tool Execution
    if do_execute:
        state = await tool_execution_step(state)
        if state.get("error"):
            return state
    else:
        print("[Tool Execution] Skipped")
    state.pop("error", None)
    return state

# ======================== Summarizer ========================
# Answer letters scanned for in the model's top-logprob candidates.
ABC_CHOICES = ['A', 'B', 'C', 'D', 'E']
# Numeric weight paired with each letter above (A=5 ... E=1).
ABC_WEIGHTS = [5, 4, 3, 2, 1]


def extract_choice_logprobs(top_logprobs) -> Dict[str, float]:
    """Return the best log-probability found for each answer letter A-E.

    A candidate counts for a letter when its ``token``, stripped of
    surrounding whitespace, starts with that letter. Letters with no matching
    candidate get -100.0.

    Args:
        top_logprobs: Iterable of objects with ``token`` and ``logprob``
            attributes.
    """
    missing_logprob = -100.0  # default logprob for missing
    return {
        letter: max(
            (
                candidate.logprob
                for candidate in top_logprobs
                if candidate.token.strip().startswith(letter)
            ),
            default=missing_logprob,
        )
        for letter in ABC_CHOICES
    }

def extract_probs_from_logprobs(logprobs_dict):
    """Turn per-letter log-probabilities into a normalized probability vector.

    Args:
        logprobs_dict: Mapping from answer letter ('A'..'E') to log-probability;
            missing letters default to -100.0.

    Returns:
        NumPy array of five probabilities summing to 1, ordered E, D, C, B, A
        (i.e. reversed relative to the letters).
    """
    logprobs = np.array(
        [logprobs_dict.get(k, -100.0) for k in ['A', 'B', 'C', 'D', 'E']],
        dtype=float,
    )
    # Subtract the maximum before exponentiating: the normalized result is
    # mathematically unchanged, but very negative inputs no longer underflow to
    # an all-zero vector (which would make the division 0/0 = NaN).
    exp_logprobs = np.exp(logprobs - np.max(logprobs))
    probs = exp_logprobs / np.sum(exp_logprobs)
    return probs[::-1]  # E→A → 1~5

def compute_hvs_weights(avg_q, num_levels=5, b=1.0):
    """Return normalized Gaussian-shaped weights over the levels 1..num_levels.

    Level ``i`` gets weight ``exp(-b * (avg_q - i) ** 2)``; the weights are
    then divided by their sum, so levels closest to ``avg_q`` weigh most.

    Args:
        avg_q: Centre of the weighting.
        num_levels: Number of integer levels, starting at 1.
        b: Sharpness; larger values concentrate the weight around ``avg_q``.
    """
    levels = np.arange(1, num_levels + 1)
    distance_sq = (avg_q - levels) ** 2
    unnormalized = np.exp(-b * distance_sq)
    total = np.sum(unnormalized)
    return unnormalized / total

def compute_hvs_final_score(logprobs_dict, avg_q, b=1.0):
    """Combine answer-letter probabilities with tool-score weights into one float.

    Args:
        logprobs_dict: Letter -> log-probability mapping, passed to
            ``extract_probs_from_logprobs``.
        avg_q: Average tool score, passed to ``compute_hvs_weights`` as the
            centre of the level weighting.
        b: Sharpness forwarded to ``compute_hvs_weights``.

    Returns:
        The combined score as a Python float.

    NOTE(review): ``v_i`` is a single scalar and ``probs`` is normalized to sum
    to 1 by ``extract_probs_from_logprobs``, so ``sum(probs * v_i)`` equals
    ``v_i`` up to rounding — the log-probabilities do not influence the result.
    This looks unintended (an element-wise combination of ``probs``, ``alpha``
    and the level values may have been meant); confirm the intended formula.
    """
    probs = extract_probs_from_logprobs(logprobs_dict)
    alpha = compute_hvs_weights(avg_q, num_levels=5, b=b)
    # Expected level under the weights alpha (a scalar).
    v_i = np.sum(alpha * np.arange(1, 6))  # weighted center
    return float(np.sum(probs * v_i))  

def format_distortion_text(detected: Dict[str, List], analyzed: Optional[Dict[str, List]] = None) -> str:
    """Render detected/analyzed distortions as a plain-text bullet list.

    For every region in ``detected`` a ``"<region>:"`` header is emitted,
    followed by, in order of preference:

    * one ``- type (severity): explanation`` line per analysis entry, when the
      region has a list of analysis dicts;
    * one ``- <distortion>`` line per detected distortion;
    * ``No visible distortions.`` when neither is available.

    Args:
        detected: Region -> list of detected distortions. ``None`` is treated
            as empty.
        analyzed: Optional region -> list of analysis dicts. ``None``, missing
            regions and non-list values are treated as "no analysis".

    Returns:
        The lines joined with newlines (empty string when nothing was detected).
    """
    result_lines = []
    for obj, det_list in (detected or {}).items():
        result_lines.append(f"{obj}:")
        ana_list = (analyzed or {}).get(obj) or []
        has_analysis = (
            isinstance(ana_list, list)
            and bool(ana_list)
            and isinstance(ana_list[0], dict)
        )
        if has_analysis:
            for entry in ana_list:
                if isinstance(entry, dict):
                    # The entries come from parsed LLM output, so tolerate
                    # missing fields instead of raising KeyError.
                    result_lines.append(
                        f"- {entry.get('type', 'Unknown')} "
                        f"({entry.get('severity', 'Unknown')}): "
                        f"{entry.get('explanation', '')}"
                    )
                else:
                    result_lines.append(f"- {entry}")
        elif det_list:
            result_lines.extend(f"- {d}" for d in det_list)
        else:
            result_lines.append("No visible distortions.")
    return "\n".join(result_lines)
def convert_score_to_quality(score: float) -> str:
    """Map a numeric score to a coarse label.

    ``score >= 4.0`` is "Good", ``score >= 2.5`` is "Moderate", anything else
    is "Poor".
    """
    # Thresholds are checked from highest to lowest; the first match wins.
    for threshold, label in ((4.0, "Good"), (2.5, "Moderate")):
        if score >= threshold:
            return label
    return "Poor"
        
def summarize_query_prompt(state: IQAState, choice_dict: Dict[str, str]) -> str:
    """Build the user-message text for the summarizer model.

    The text always starts with an optional full-reference note and the
    question. What follows depends on whether object names and answer choices
    are present:

    * objects + choices: the choices and the distortion analysis;
    * no objects + choices: the choices, tool scores and distortion analysis;
    * no objects + no choices: a fixed A-E quality scale, tool scores and
      distortion analysis;
    * objects + no choices: the literal notice
      ``"Unsupported: local object + open-ended query."``.

    Args:
        state: Pipeline state (read only).
        choice_dict: Letter -> choice text; empty or ``None`` for open-ended
            questions.
    """
    query_info = state.get("query", "")
    object_names = state.get("object_names", [])
    # These keys exist with value None until the corresponding step has run, so
    # ``get(key, default)`` alone would still return None.
    quality_scores = state.get("quality_scores") or {}
    reference_type = state.get("reference_type", "No-Reference")

    if reference_type == "Full-Reference":
        ref_text = "Note: The first image is the distorted image, and the second is the reference image.\n\n"
    else:
        ref_text = ""

    # Distortion Text
    distortion_text = format_distortion_text(
        detected=state.get("distortions") or {},
        analyzed=state.get("distortion_analysis") or {},
    )

    # Tool Response Text: one "<region>:" header per region, then one line per
    # score. Values are either a raw number or a (tool_name, score) tuple.
    tool_lines = []
    for obj_name, dist_map in quality_scores.items():
        tool_lines.append(f"{obj_name}:\n")
        for dist_type, value in dist_map.items():
            if isinstance(value, tuple):
                score = float(value[1])
            elif isinstance(value, (int, float)):
                score = float(value)
            else:
                score = None
            if score is not None:
                tool_lines.append(f"- {dist_type}: {score:.2f}\n")
            else:
                tool_lines.append(f"- {dist_type}: [Invalid score]\n")
    tool_text = "".join(tool_lines)

    question_text = f"Question: {query_info}\n\n"
    choices_text = "\n".join(f"{k}. {v}" for k, v in (choice_dict or {}).items())
    tool_section = f"Tool response:\n{tool_text} (where 1=Bad, 2=Poor, 3=Fair, 4=Good, 5=Excellent)\n\n"

    # === Local + Choice ===
    if object_names and choice_dict:
        return (
            ref_text
            + question_text
            + "Answer choices:\n"
            + choices_text
            + "\n\n"
            + f"Distortion analysis:\n{distortion_text}\n\n"
        )

    # === Global + Choice ===
    if not object_names and choice_dict:
        return (
            ref_text
            + question_text
            + "Answer choices:\n"
            + choices_text
            + "\n\n"
            + tool_section
            + f"Distortion analysis:\n{distortion_text}"
        )

    # === Global + No Choice ===
    if not object_names and not choice_dict:
        return (
            ref_text
            + question_text
            + "Answer choices:\nA. Excellent\nB. Good\nC. Fair\nD. Poor\nE. Bad\n\n"
            + tool_section
            + f"Distortion analysis:\n{distortion_text}"
        )

    # === Local + No Choice ===
    return "Unsupported: local object + open-ended query."

async def summarize_quality_step(state: IQAState, model) -> IQAState:
    """Produce the final answer (choice text or numeric score) for the query.

    * Multiple-choice or ``query_type == "others"``: the model is called with
      the summary prompt and the image(s); ``final_response`` and ``reason``
      are taken from its JSON reply.
    * Global open-ended query: the model is called with ``logprobs=True`` and
      the first token's top log-probabilities are combined with the average
      tool score via ``compute_hvs_final_score``.
    * Object-level or "others" queries without choices are reported as
      unsupported in ``final_response``.

    Up to two attempts are made; an exception in an attempt triggers a retry.

    Returns:
        The same ``state`` object, with ``error`` removed on success or set to
        a dict (``message`` plus ``detail`` of the last failure) otherwise.
    """
    MAX_RETRIES = 2
    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            query_type = state.get("query_type", "IQA")
            object_names = state.get("object_names", [])
            choices_list = state.get("choices_list", [])
            # The key may be present with value None (tool execution skipped).
            quality_scores = state.get("quality_scores") or {}
            choice_dict = {chr(65 + i): v for i, v in enumerate(choices_list)} if choices_list else {}
            user_question=(
                                f"Question: {state['query']}\n\n"
                                "Answer choices:\n" + "\n".join(f"{k}. {v}" for k,v in (choice_dict or {}).items())
                            )
            if (object_names and not choice_dict) or (query_type == "others" and not choice_dict):
                state["final_response"] = "Unsupported: local/others + open-ended query."
                return state

            # ============ Build prompt ============
            if query_type == "others":
                prompt = build_summary_others_prompt(choices_list)
                query_prompt = f"Question: {state.get('query', '')}\n\nAnswer choices:\n" + "\n".join([f"{k}. {v}" for k, v in choice_dict.items()])
            else:
                query_prompt = summarize_query_prompt(state, choice_dict)
                prompt = build_summary_choice_prompt(choices_list,user_question)

            # Match the notice only at the start: the prompt embeds the user's
            # question, so a substring test would misfire whenever the question
            # itself contains the word "Unsupported".
            if query_prompt.startswith("Unsupported"):
                state["final_response"] = query_prompt
                return state

            # ============ Prepare image input ============
            dist_path = state["images"][0]
            ref_path = state["images"][1] if state.get("reference_type") == "Full-Reference" else None

            image_contents = []
            for path in [dist_path, ref_path]:
                if path:
                    image_contents.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{encode_image(path)}"}
                    })

            # ============ others or Choice Summary ============
            if query_type == "others" or choice_dict:
                messages = [Message("system", prompt), Message("user", [{"type": "text", "text": query_prompt}] + image_contents)]
                response = await model.ainvoke(messages)
                parsed = extract_json(response)
                state["final_response"] = parsed["final_response"]
                state["reason"] = parsed.get("reason", "")
                print("[Summary Final Response]", state["final_response"])
                state.pop("error", None)
                return state

            # ============ Global + No Choice (numeric score) ============
            if not object_names and not choice_dict:
                # Collect the numeric tool scores first: without any, the
                # average is NaN and so would be the final score, so fail
                # explicitly before spending a model call.
                scores = []
                for dist_map in quality_scores.values():
                    for v in dist_map.values():
                        if isinstance(v, tuple):
                            scores.append(float(v[1]))
                        elif isinstance(v, (int, float)):
                            scores.append(float(v))
                if not scores:
                    state["error"] = {"message": "Summarizer failed: no numeric tool scores available"}
                    return state

                prompt = build_summary_nochoice_prompt()
                messages = [Message("system", prompt), Message("user", query_prompt)]
                response = await model.ainvoke_full(messages, logprobs=True, top_logprobs=5)

                # Only the first generated token is inspected.
                top_logprobs = response.choices[0].logprobs.content[0].top_logprobs
                logprobs_dict = extract_choice_logprobs(top_logprobs)
                for ch in ABC_CHOICES:
                    logprobs_dict.setdefault(ch, -100.0)

                avg_q = np.mean(scores)
                score = compute_hvs_final_score(logprobs_dict, avg_q)
                state["final_response"] = score
                print("[Summary Final Score]", score)
                state.pop("error", None)
                return state

        except Exception as e:
            last_error = str(e)
            state["error"] = {"message": f"Summarizer failed: {e}"}
            # Back off only when another attempt will actually follow.
            if attempt + 1 < MAX_RETRIES:
                await asyncio.sleep(1)

    # Keep the cause of the last failure so the error log stays actionable.
    state["error"] = {"message": "Summarizer failed after retries", "detail": last_error}
    return state

# # Async Function to Run IQA Analysis
async def run_iqa_analysis(question, images,choices_list=None):
    """Run the four pipeline stages in order and return the final state.

    Stages: planner, distortion detection/analysis, tool selection/execution,
    summarizer. The first stage that leaves ``error`` set in the state stops
    the pipeline; the state is returned as it is at that point.

    Args:
        question: The user's question.
        images: Image paths; more than one path selects "Full-Reference" mode.
        choices_list: Optional answer choices.
    """
    reference_type = "Full-Reference" if len(images) > 1 else "No-Reference"
    state = IQAState(
        query=question,
        query_type=None,
        images=images,
        choices_list=choices_list or [],
        distortions=None,
        reference_type=reference_type,
        required_tool=None,
        object_names=None,
        distortion_source=None,
        quality_scores=None,
        final_response=None,
        plan=None,
        error=None,
        reason=None,
        messages=[]
    )
    model = ChatModel(model="gpt-4o", temperature=0)

    # step 1: planner (updates `state` in place)
    await planner_step_with_retry(state, model)

    # step 2: distortion detection & analyze
    if not state.get("error"):
        state = await distortion_step(state, model)

    # step 3: tool selection & executor
    if not state.get("error"):
        state = await tool_step(state)

    # step 4: summarizer
    if not state.get("error"):
        state = await summarize_quality_step(state, model)

    return state

async def run_iqa_analysis_with_retry(question, images, choices_list=None, max_retries=2):
    """Run ``run_iqa_analysis`` until it finishes without an error.

    Each attempt starts from a fresh state.

    Args:
        question: The user's question.
        images: Image paths for the query.
        choices_list: Optional answer choices.
        max_retries: Maximum number of attempts; values below 1 are treated as
            1 so that a state is always returned.

    Returns:
        The state of the first error-free attempt, or of the last attempt when
        all of them failed.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        state = await run_iqa_analysis(question, images, choices_list)
        error = state.get("error")
        if not error:
            return state
        # Errors are normally dicts with a "message" key; fall back to the raw
        # value rather than failing while reporting a failure.
        message = error.get("message", error) if isinstance(error, dict) else error
        print(f"[Retry {attempt+1}] Error encountered: {message}")
    print("[Final Failure] All retries failed.")
    return state

# Run IQA Analysis
async def main():
    """Run the pipeline over the QBench samples and write two JSON arrays.

    Successful samples are appended to ``output_path``; failed ones (their
    error payload, extended with the image paths and the question) are
    appended to ``errors_path``. Records are written as each sample finishes.
    """
    output_path = "/root/IQA/IQA-Agent/results/gpt-4o/final_v1/results/qbench_gpt_test1.json"
    errors_path = "/root/IQA/IQA-Agent/results/gpt-4o/final_v1/results/qbench_gpt_test1_error.json"

    for path in (output_path, errors_path):
        with open(path, "w", encoding="utf-8") as fh:
            fh.write("[\n")

    # Samples are split across two files, so the comma separator has to be
    # tracked per file. Deciding it from the global sample index leaves a
    # trailing comma (invalid JSON) in whichever file does not receive the
    # last sample.
    has_records = {output_path: False, errors_path: False}

    def append_record(path, record):
        # Write the separator *before* every record except the first.
        with open(path, "a", encoding="utf-8") as fh:
            if has_records[path]:
                fh.write(",\n")
            json.dump(record, fh, ensure_ascii=False, indent=2)
        has_records[path] = True

    try:
        img_paths, questions, choices_list, correct_choices,types,concerns = QBench_data()
        # img_paths, questions, choices_list, correct_choices = QBench_error_recovery_data()
        samples = zip(img_paths, questions, choices_list, correct_choices, types, concerns)
        for i, (img_path, question, choices, correct, q_type, concern) in enumerate(samples):
            print(f"[{i}] Running analysis...")
            output_state = await run_iqa_analysis_with_retry(question, img_path, choices)
            # Drop the chat history before the state is serialized.
            output_state.pop("messages", None)
            result = {
                "correct_choice": correct,
                # NOTE(review): the key "type," (with a trailing comma) looks
                # like a typo for "type"; kept as-is because existing result
                # consumers may rely on it — confirm before renaming.
                "type,": q_type,
                "concern": concern,
                "state": output_state
            }

            if output_state.get("error"):
                output_state["error"]["image_paths"] = img_path
                output_state["error"]["query"] = question
                append_record(errors_path, output_state["error"])
            else:
                append_record(output_path, result)
            cleanup_temp_images()
    finally:
        # Always close both arrays so that partial results remain parseable
        # even when the run is interrupted.
        for path in (output_path, errors_path):
            with open(path, "a", encoding="utf-8") as fh:
                fh.write("\n]\n" if has_records[path] else "]\n")


# Script entry point: run the benchmark driver when executed directly
# (not when this module is imported).
if __name__ == "__main__":
    asyncio.run(main())